All Articles

PyTorch Lightning: Easy Class-wise Precision and Recall Logging

It was hard to find a proper class-wise precision and recall logger on the web, so here is a simple one you can use:

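In your LightningModule, add the following. Each `validation_step` must return a dict holding the batch's ground-truth labels under `"val_true"` and its predicted class indices under `"val_preds"`: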

class Experiment(pl.LightningModule):
    """LightningModule that logs per-class precision and recall each validation epoch.

    Each ``validation_step`` (elided here) is expected to return a dict holding
    the batch's ground-truth class indices under ``"val_true"`` and the
    predicted class indices under ``"val_preds"``.
    """

    def __init__(self, model, loss, n_classes, dual_images=False) -> None:
        super().__init__()
        self.model = model
        self.loss = loss
        # Required by _calc_pr to size the per-class counters.
        self.n_classes = n_classes
        self.dual_images = dual_images


    ...

    ...
    
    def _calc_pr(self, outputs):
        """Compute per-class precision and recall over one validation epoch.

        Args:
            outputs: list of per-step dicts, each holding tensors of integer
                class indices under ``"val_true"`` (targets) and
                ``"val_preds"`` (predictions) — assumed to be label indices in
                ``[0, n_classes)``, not logits.

        Returns:
            ``(classwise_precision, classwise_recall)``: two float tensors of
            length ``self.n_classes`` where
            ``precision[c] = TP / (TP + FP)`` and
            ``recall[c] = TP / (TP + FN)``. A class that was never predicted
            (precision) or never occurs (recall) gets 0.0 instead of NaN.
        """
        y_true = torch.cat([x["val_true"] for x in outputs]).reshape(-1).long()
        y_pred = torch.cat([x["val_preds"] for x in outputs]).reshape(-1).long()
        n = self.n_classes
        # True positives: correctly classified samples, counted per class.
        tp = torch.bincount(y_true[y_true == y_pred], minlength=n).float()
        # TP + FP: how often each class was predicted.
        predicted = torch.bincount(y_pred, minlength=n).float()
        # TP + FN: how often each class actually occurs (its support).
        support = torch.bincount(y_true, minlength=n).float()
        # tp never exceeds either denominator, so clamping a zero denominator
        # to 1 maps the undefined 0/0 case to 0.0 and leaves all other ratios
        # unchanged.
        classwise_precision = tp / predicted.clamp(min=1)
        classwise_recall = tp / support.clamp(min=1)
        return classwise_precision, classwise_recall


    def validation_epoch_end(self, outputs):
        """Log per-class precision and recall at the end of a validation epoch.

        Note: this hook exists in PyTorch Lightning 1.x; Lightning 2.0 replaced
        it with ``on_validation_epoch_end``, where step outputs must be
        collected manually.
        """
        ...
        # pr, rc
        classwise_precision, classwise_recall = self._calc_pr(outputs)
        # NOTE: with sync_dist=True each rank's per-class value is reduced
        # across ranks (mean by default), which only approximates the global
        # metric when ranks see different class distributions.
        self.log_dict({f"class_{i}_precision": val for i, val in enumerate(classwise_precision.tolist())}, sync_dist=True)
        self.log_dict({f"class_{i}_recall": val for i, val in enumerate(classwise_recall.tolist())}, sync_dist=True)
        ...